{
    "cells": [
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "# Quantize and speed up any LLM"
            ]
        },
        {
            "cell_type": "raw",
            "metadata": {
                "vscode": {
                    "languageId": "raw"
                }
            },
            "source": [
                "<a target=\"_blank\" href=\"https://colab.research.google.com/github/PrunaAI/pruna/blob/v|version|/docs/tutorials/llm_quantization_compilation_acceleration.ipynb\">\n",
                "    <img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/>\n",
                "</a>"
            ]
        },
        {
            "cell_type": "raw",
            "metadata": {
                "raw_mimetype": "text/restructuredtext",
                "vscode": {
                    "languageId": "raw"
                }
            },
            "source": [
                "This tutorial demonstrates how to use the ``pruna`` package to optimize both the latency and the memory footprint of any LLM from the ``transformers`` package.\n",
                "We will use the ``meta-llama/Llama-3.2-1b-Instruct`` model as an example, but this tutorial works with any language model.\n",
                "The results shown here use the ``hqq`` quantizer, but the same steps work with ``gptq``, ``llm_int8``, and ``higgs`` (the latter requires ``pruna_pro``)."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# if you are not running the latest version of this tutorial, make sure to install the matching version of pruna\n",
                "# the following command will install the latest version of pruna\n",
                "%pip install pruna"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 1. Loading the LLM\n",
                "\n",
                "First, load your LLM and its associated tokenizer."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import torch\n",
                "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
                "\n",
                "model_id = \"meta-llama/Llama-3.2-1b-Instruct\"\n",
                "\n",
                "# We observed better performance with bfloat16 precision.\n",
                "model = AutoModelForCausalLM.from_pretrained(\n",
                "    model_id,\n",
                "    torch_dtype=torch.bfloat16,\n",
                "    low_cpu_mem_usage=True,\n",
                "    device_map=\"cuda\",\n",
                ")\n",
                "tokenizer = AutoTokenizer.from_pretrained(model_id)"
            ]
        },
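        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Since this tutorial also targets the memory footprint, it can help to record a baseline right after loading. The cell below is a minimal sketch rather than part of the original tutorial: it relies on `get_memory_footprint()` from `transformers` and on `torch.cuda.memory_allocated()`, and the name `baseline_bytes` is only illustrative. Calling `torch.cuda.memory_allocated()` again after smashing should give a rough before/after comparison."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "# Size of the model's parameters and buffers, in bytes.\n",
                "baseline_bytes = model.get_memory_footprint()\n",
                "print(f\"Model footprint: {baseline_bytes / 1024**3:.2f} GiB\")\n",
                "\n",
                "# Memory currently held by tensors on the GPU (includes the model weights).\n",
                "print(f\"CUDA memory allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GiB\")"
            ]
        },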
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 2. Test the original model speed"
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import time\n",
                "\n",
                "# Warmup the model\n",
                "for _ in range(3):\n",
                "    with torch.no_grad():\n",
                "        inp = tokenizer(\n",
                "            [\"This is a test of this large language model\"], return_tensors=\"pt\"\n",
                "        )\n",
                "        input_ids = inp[\"input_ids\"].cuda()\n",
                "        generated_ids = model.generate(\n",
                "            input_ids,\n",
                "            max_length=input_ids.shape[1] + 56,\n",
                "            min_length=input_ids.shape[1] + 56,\n",
                "        )\n",
                "        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
                "\n",
                "torch.cuda.synchronize()\n",
                "t = time.time()\n",
                "with torch.no_grad():\n",
                "    inp = tokenizer(\n",
                "        [\"This is a test of this large language model\"], return_tensors=\"pt\"\n",
                "    )\n",
                "    input_ids = inp[\"input_ids\"].cuda()\n",
                "    generated_ids = model.generate(\n",
                "        input_ids,\n",
                "        max_length=input_ids.shape[1] + 56,\n",
                "        min_length=input_ids.shape[1] + 56,\n",
                "    )\n",
                "    text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
                "print(text)\n",
                "torch.cuda.synchronize()\n",
                "print(time.time() - t)"
            ]
        },
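        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "The warmup-then-measure pattern used above can be wrapped in a small helper. This is a sketch under assumptions, not part of the original tutorial: `time_generation` is a hypothetical helper (it does not come from `pruna` or `transformers`), and it reuses `model` and `input_ids` from the previous cells. It also reports throughput, computed from the number of tokens actually generated."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import time\n",
                "\n",
                "import torch\n",
                "\n",
                "\n",
                "def time_generation(generate_fn, n_warmup=3):\n",
                "    \"\"\"Run `generate_fn` a few times as warmup, then time one more call.\"\"\"\n",
                "    with torch.no_grad():\n",
                "        for _ in range(n_warmup):\n",
                "            generate_fn()\n",
                "        # Wait for pending GPU work so it is not counted in the measurement.\n",
                "        torch.cuda.synchronize()\n",
                "        start = time.perf_counter()\n",
                "        output = generate_fn()\n",
                "        # Wait for the generation kernels to finish before stopping the clock.\n",
                "        torch.cuda.synchronize()\n",
                "    return output, time.perf_counter() - start\n",
                "\n",
                "\n",
                "generated_ids, elapsed = time_generation(\n",
                "    lambda: model.generate(input_ids, max_new_tokens=56)\n",
                ")\n",
                "new_tokens = generated_ids.shape[1] - input_ids.shape[1]\n",
                "print(f\"{elapsed:.3f} s, {new_tokens / elapsed:.1f} tokens/s\")"
            ]
        },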
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 3. Initializing the Smash Config"
            ]
        },
        {
            "cell_type": "raw",
            "metadata": {
                "raw_mimetype": "text/restructuredtext",
                "vscode": {
                    "languageId": "raw"
                }
            },
            "source": [
                "Next, initialize the ``smash_config``. Here we use the :doc:`hqq </compression>` and :doc:`torch_compile </compression>` algorithms."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from pruna import SmashConfig\n",
                "\n",
                "smash_config = SmashConfig(\n",
                "    {\n",
                "        \"hqq\": {\n",
                "            # 2 and 8 bits also work, but 4 bits gives the best performance\n",
                "            \"weight_bits\": 4,\n",
                "            # float16 also works, but bfloat16 gives better performance\n",
                "            \"compute_dtype\": \"torch.bfloat16\",\n",
                "        },\n",
                "        \"torch_compile\": {\n",
                "            \"fullgraph\": True,\n",
                "            \"mode\": \"max-autotune\",\n",
                "            # If the model is not compatible with CUDA graphs, comment out the line above\n",
                "            # and uncomment the line below\n",
                "            # \"mode\": \"max-autotune-no-cudagraphs\",\n",
                "            # \"max_kv_cache_size\": 400,  # uncomment to use a custom KV cache size\n",
                "        },\n",
                "    }\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 4. Smashing the Model\n",
                "\n",
                "Now, smash the model. This can take up to 30 seconds."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "from pruna import smash\n",
                "\n",
                "# Smash the model and rebind `model` so the cells below run the smashed version\n",
                "model = smash(\n",
                "    model=model,\n",
                "    smash_config=smash_config,\n",
                ")"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### 5. Running the Model"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "Finally, run the model to generate the text you want.\n",
                "Note that the model needs a short warmup (less than a minute) the first time it runs.\n",
                "\n",
                "NB: Currently, the quantized and compiled LLM only supports the default sampling strategy, and you need to generate tokens with `model.generate(input_ids, max_new_tokens=X)`, where `X` is the number of tokens you want to produce. We plan to support other sampling schemes (DoLa, contrastive search, etc.) in the near future."
            ]
        },
        {
            "cell_type": "code",
            "execution_count": null,
            "metadata": {},
            "outputs": [],
            "source": [
                "import time\n",
                "\n",
                "# Warmup the model\n",
                "for _ in range(3):\n",
                "    with torch.no_grad():\n",
                "        inp = tokenizer(\n",
                "            [\"This is a test of this large language model\"], return_tensors=\"pt\"\n",
                "        )\n",
                "        input_ids = inp[\"input_ids\"].cuda()\n",
                "        generated_ids = model.generate(input_ids, max_new_tokens=56)\n",
                "        text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
                "\n",
                "torch.cuda.synchronize()\n",
                "t = time.time()\n",
                "with torch.no_grad():\n",
                "    inp = tokenizer(\n",
                "        [\"This is a test of this large language model\"], return_tensors=\"pt\"\n",
                "    )\n",
                "    input_ids = inp[\"input_ids\"].cuda()\n",
                "    generated_ids = model.generate(input_ids, max_new_tokens=56)\n",
                "    text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)\n",
                "print(text)\n",
                "torch.cuda.synchronize()\n",
                "print(time.time() - t)"
            ]
        },
        {
            "cell_type": "markdown",
            "metadata": {},
            "source": [
                "### Wrap Up"
            ]
        },
        {
            "cell_type": "raw",
            "metadata": {
                "raw_mimetype": "text/restructuredtext",
                "vscode": {
                    "languageId": "raw"
                }
            },
            "source": [
                "Congratulations! You've optimized your LLM using HQQ quantization and TorchCompile!\n",
                "The quantized model uses less memory and runs faster while maintaining good quality.\n",
                "You can also try the other quantizers from ``pruna`` (``gptq``, ``llm_int8``), or the ``higgs`` quantizer from ``pruna_pro``, which also speeds up batch inference and can maintain quality at low bit widths.\n",
                "\n",
                "Want more optimization techniques? Check out our other tutorials!"
            ]
        }
    ],
    "metadata": {
        "kernelspec": {
            "display_name": "Pruna",
            "language": "python",
            "name": "python3"
        },
        "language_info": {
            "codemirror_mode": {
                "name": "ipython",
                "version": 3
            },
            "file_extension": ".py",
            "mimetype": "text/x-python",
            "name": "python",
            "nbconvert_exporter": "python",
            "pygments_lexer": "ipython3",
            "version": "3.11.13"
        }
    },
    "nbformat": 4,
    "nbformat_minor": 2
}
